{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers, optimizers, datasets\n",
    "from tensorflow.keras.layers import Dense, Dropout, Flatten\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D\n",
    "\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2'}\n",
    "\n",
    "def mnist_dataset():\n",
    "    (x, y), (x_test, y_test) = datasets.mnist.load_data()\n",
    "    x = x.reshape(x.shape[0], 28, 28,1)\n",
    "    x_test = x_test.reshape(x_test.shape[0], 28, 28,1)\n",
    "    ds = tf.data.Dataset.from_tensor_slices((x, y))\n",
    "    ds = ds.map(prepare_mnist_features_and_labels)\n",
    "    ds = ds.take(20000).shuffle(20000).batch(100)\n",
    "    \n",
    "    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
    "    test_ds = test_ds.map(prepare_mnist_features_and_labels)\n",
    "    test_ds = test_ds.take(20000).shuffle(20000).batch(20000)\n",
    "    return ds, test_ds\n",
    "\n",
    "def prepare_mnist_features_and_labels(x, y):\n",
    "    x = tf.cast(x, tf.float32) / 255.0\n",
    "    y = tf.cast(y, tf.int64)\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3136"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "7*7*64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.Sequential([\n",
    "    Conv2D(32, (5, 5), activation='relu', padding='same'),\n",
    "    MaxPooling2D(pool_size=2, strides=2),\n",
    "    Conv2D(64, (5, 5), activation='relu', padding='same'),\n",
    "    MaxPooling2D(pool_size=2, strides=2),\n",
    "    Flatten(), #N*7*7*64 =>N*3136\n",
    "    layers.Dense(128, activation='tanh'), #N*128\n",
    "    layers.Dense(10, activation='softmax')]) #N*10\n",
    "optimizer = optimizers.Adam(0.0001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 编译， fit以及evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0920 16:06:11.339046 4318938560 deprecation.py:323] From /Users/jerrik/a3/envs/tf2.0/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200/200 [==============================] - 31s 154ms/step - loss: 0.9428 - accuracy: 0.6292\n",
      "Epoch 2/5\n",
      "200/200 [==============================] - 28s 139ms/step - loss: 0.2700 - accuracy: 0.9200\n",
      "Epoch 3/5\n",
      "200/200 [==============================] - 34s 169ms/step - loss: 0.1815 - accuracy: 0.9445\n",
      "Epoch 4/5\n",
      "200/200 [==============================] - 31s 154ms/step - loss: 0.1368 - accuracy: 0.9585\n",
      "Epoch 5/5\n",
      "200/200 [==============================] - 32s 159ms/step - loss: 0.1099 - accuracy: 0.9680\n",
      "1/1 [==============================] - 7s 7s/step - loss: 0.0984 - accuracy: 0.9710\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.09840553253889084, 0.971]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.compile(optimizer=optimizer,\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "train_ds, test_ds = mnist_dataset()\n",
    "model.fit(train_ds, epochs=5)\n",
    "model.evaluate(test_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
